{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [13:44:39] Enabling RDKit 2019.09.1 jupyter extensions\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fa66808cf90>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.distributions as D\n",
    "import torch_geometric\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_scatter import scatter_add\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit import rdBase\n",
    "from rdkit.Chem import AllChem, Draw, rdFMCS, rdMolTransforms\n",
    "from rdkit.Chem import PandasTools\n",
    "from rdkit.Chem.rdMolAlign import AlignMol\n",
    "\n",
    "import deepdock \n",
    "from deepdock.utils.distributions import *\n",
    "from deepdock.utils.data import *\n",
    "from deepdock.models import *\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# One shared seed for the NumPy RNG and the PyTorch CPU/CUDA RNGs.\n",
    "SEED = 123\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the GPU when PyTorch can see one, otherwise the CPU.\n",
    "# (Assign device = 'cpu' here instead to force CPU execution.)\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "# Both encoders are built with the same depth and dropout settings.\n",
    "encoder_kwargs = dict(residual_layers=10, dropout_rate=0.10)\n",
    "ligand_model = LigandNet(28, **encoder_kwargs)\n",
    "target_model = TargetNet(4, **encoder_kwargs)\n",
    "\n",
    "# `dist_threhold` is spelled exactly as the keyword DeepDock is called with here.\n",
    "model = DeepDock(\n",
    "    ligand_model,\n",
    "    target_model,\n",
    "    hidden_dim=64,\n",
    "    n_gaussians=10,\n",
    "    dropout_rate=0.10,\n",
    "    dist_threhold=7.,\n",
    ")\n",
    "model = model.to(device)\n",
    "\n",
    "# Load the stored weights onto whichever device was selected above.\n",
    "checkpoint_path = deepdock.__path__[0]+'/../Trained_models/DeepDock_pdbbindv2019_13K_minTestLoss.chk'\n",
    "checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))\n",
    "model.load_state_dict(checkpoint['model_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complexes from pdbBind: 285\n"
     ]
    }
   ],
   "source": [
    "# Archive holding the pre-processed CASF-2016 complexes, located relative to the deepdock package.\n",
    "casf_archive = deepdock.__path__[0]+'/../data/dataset_CASF-2016_285.tar'\n",
    "\n",
    "# None for both node limits, exactly as the dataset is requested here.\n",
    "db_complex = PDBbind_complex_dataset(\n",
    "    data_path=casf_archive,\n",
    "    min_target_nodes=None,\n",
    "    max_ligand_nodes=None,\n",
    ")\n",
    "print('Complexes from pdbBind:', len(db_complex))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class input_dataset(Dataset):\n",
    "    \"\"\"Dataset that pairs every ligand molecule with one shared target mesh.\n",
    "\n",
    "    Each molecule is converted once, at construction time, with\n",
    "    ``from_networkx(mol2graph.mol_to_nx(m))``. Every item returned by ``get``\n",
    "    is the tuple ``(ligand_graph, target_mesh, label)``.\n",
    "\n",
    "    Args:\n",
    "        mols: iterable of molecules accepted by ``mol2graph.mol_to_nx``.\n",
    "        target_mesh: object handed back unchanged as the second element of every item.\n",
    "        labels: optional sequence indexed like ``mols``; when omitted the labels\n",
    "            are the positions ``0 .. len(mols) - 1``.\n",
    "        transform: forwarded to the base ``Dataset``.\n",
    "        pre_transform: forwarded to the base ``Dataset``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mols, target_mesh, labels=None, transform=None, pre_transform=None):\n",
    "        # Forward both transforms to the base class; without this they would be\n",
    "        # accepted by the signature and then silently ignored.\n",
    "        super().__init__(transform=transform, pre_transform=pre_transform)\n",
    "\n",
    "        self.mols = [from_networkx(mol2graph.mol_to_nx(m)) for m in mols]\n",
    "        self.target_mesh = target_mesh\n",
    "        # Fall back to positional labels when the caller supplies none.\n",
    "        self.labels = labels if labels is not None else range(len(self.mols))\n",
    "\n",
    "    def len(self):\n",
    "        \"\"\"Return the number of converted ligand graphs.\"\"\"\n",
    "        return len(self.mols)\n",
    "\n",
    "    def get(self, idx):\n",
    "        \"\"\"Return ``(ligand_graph, target_mesh, label)`` for position ``idx``.\"\"\"\n",
    "        return self.mols[idx], self.target_mesh, self.labels[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6min 1s, sys: 7.56 s, total: 6min 9s\n",
      "Wall time: 3min 50s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Cut-offs compared against `dist`, widest first. `prob` is zeroed in place at every\n",
    "# step, so the masks accumulate as the cut-off shrinks from 10 down to 3.\n",
    "DISTANCE_CUTOFFS = (10, 7, 5, 3)\n",
    "DECOY_DIR = deepdock.__path__[0]+'/../data/CASF-2016/decoys_docking/'\n",
    "\n",
    "# One array per batch; each row is pdb id, compound name, then the scores for\n",
    "# cut-offs 3, 5, 7, 10 and for no cut-off. Stacking names with numbers makes\n",
    "# numpy store every field as text.\n",
    "results = []\n",
    "model.eval()  # nothing below switches the mode back, so setting it once is enough\n",
    "\n",
    "# Scoring only reads the model outputs, so skip building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    for target_data in db_complex:\n",
    "        _, target, _, pdbid = target_data\n",
    "        decoys = Mol2MolSupplier(file=DECOY_DIR+pdbid+'_decoys.mol2',\n",
    "                                 sanitize=False, cleanupSubstructures=False)\n",
    "        decoy_names = [m.GetProp('Name') for m in decoys]\n",
    "        db_decoys = input_dataset(decoys, target, decoy_names)\n",
    "        loader_decoys = DataLoader(db_decoys, batch_size=20, shuffle=False)\n",
    "\n",
    "        # `target_batch` keeps the collated target apart from the per-complex `target`.\n",
    "        for decoy, target_batch, cpd_name in loader_decoys:\n",
    "            decoy, target_batch = decoy.to(device), target_batch.to(device)\n",
    "            pi, sigma, mu, dist, _, _, batch = model(decoy, target_batch)\n",
    "\n",
    "            # Sum over dim 1 of pi * exp(Normal(mu, sigma).log_prob(dist)).\n",
    "            normal = Normal(mu, sigma)\n",
    "            logprob = normal.log_prob(dist.expand_as(normal.loc))\n",
    "            logprob += torch.log(pi)\n",
    "            prob = logprob.exp().sum(1)\n",
    "\n",
    "            # Number of distinct ids in `batch`; computed once and reused for every reduction.\n",
    "            n_graphs = batch.unique().size(0)\n",
    "            prob_all = scatter_add(prob, batch, dim=0, dim_size=n_graphs)\n",
    "\n",
    "            prob_by_cutoff = []\n",
    "            for cutoff in DISTANCE_CUTOFFS:\n",
    "                prob[torch.where(dist > cutoff)[0]] = 0.\n",
    "                prob_by_cutoff.append(scatter_add(prob, batch, dim=0, dim_size=n_graphs))\n",
    "\n",
    "            # Reverse so the columns run 3, 5, 7, 10, then the unmasked total.\n",
    "            scores = torch.stack(prob_by_cutoff[::-1] + [prob_all], dim=1)\n",
    "            results.append(np.concatenate([np.expand_dims(np.repeat(pdbid, len(cpd_name)), axis=1),\n",
    "                                           np.expand_dims(cpd_name, axis=1),\n",
    "                                           scores.cpu().numpy()], axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(22492, 7)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PDB_ID</th>\n",
       "      <th>Cpd_Name</th>\n",
       "      <th>Score_3A</th>\n",
       "      <th>Score_5A</th>\n",
       "      <th>Score_7A</th>\n",
       "      <th>Score_10A</th>\n",
       "      <th>Score_all</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4k18</td>\n",
       "      <td>4k18_100</td>\n",
       "      <td>41.7124709276148</td>\n",
       "      <td>355.00999393874434</td>\n",
       "      <td>1234.1988897433216</td>\n",
       "      <td>1265.8331969710105</td>\n",
       "      <td>1265.8335978696152</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4k18</td>\n",
       "      <td>4k18_105</td>\n",
       "      <td>36.50107725111608</td>\n",
       "      <td>320.5826389497487</td>\n",
       "      <td>1098.7196528481893</td>\n",
       "      <td>1128.4331804224062</td>\n",
       "      <td>1128.4344001617458</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4k18</td>\n",
       "      <td>4k18_107</td>\n",
       "      <td>56.13413997037191</td>\n",
       "      <td>406.6658214522255</td>\n",
       "      <td>1261.5476637591444</td>\n",
       "      <td>1298.835192687907</td>\n",
       "      <td>1298.835338794891</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4k18</td>\n",
       "      <td>4k18_118</td>\n",
       "      <td>44.23784542273988</td>\n",
       "      <td>336.9721288402966</td>\n",
       "      <td>1252.596765731967</td>\n",
       "      <td>1287.006223774616</td>\n",
       "      <td>1287.0066010553867</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4k18</td>\n",
       "      <td>4k18_122</td>\n",
       "      <td>42.72467596870927</td>\n",
       "      <td>342.7847157990617</td>\n",
       "      <td>1211.367931058725</td>\n",
       "      <td>1252.029400636168</td>\n",
       "      <td>1252.029755610901</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  PDB_ID  Cpd_Name           Score_3A            Score_5A            Score_7A  \\\n",
       "0   4k18  4k18_100   41.7124709276148  355.00999393874434  1234.1988897433216   \n",
       "1   4k18  4k18_105  36.50107725111608   320.5826389497487  1098.7196528481893   \n",
       "2   4k18  4k18_107  56.13413997037191   406.6658214522255  1261.5476637591444   \n",
       "3   4k18  4k18_118  44.23784542273988   336.9721288402966   1252.596765731967   \n",
       "4   4k18  4k18_122  42.72467596870927   342.7847157990617   1211.367931058725   \n",
       "\n",
       "            Score_10A           Score_all  \n",
       "0  1265.8331969710105  1265.8335978696152  \n",
       "1  1128.4331804224062  1128.4344001617458  \n",
       "2   1298.835192687907   1298.835338794891  \n",
       "3   1287.006223774616  1287.0066010553867  \n",
       "4   1252.029400636168   1252.029755610901  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SCORE_COLUMNS = ['Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all']\n",
    "\n",
    "# Build the table under its own name: `results` stays the list of per-batch\n",
    "# arrays, so this cell can be re-run without re-scoring.\n",
    "results_df = pd.DataFrame(np.concatenate(results, axis=0),\n",
    "                          columns=['PDB_ID', 'Cpd_Name'] + SCORE_COLUMNS)\n",
    "\n",
    "# The per-batch arrays mix id strings with scores, so every field arrives as\n",
    "# text; turn the score columns back into floats.\n",
    "results_df[SCORE_COLUMNS] = results_df[SCORE_COLUMNS].astype(float)\n",
    "\n",
    "results_df.to_csv('Score_decoys_docking_CASF2016.csv', index=False)\n",
    "print(results_df.shape)\n",
    "results_df.head()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Keep PDB ids as text: an id such as '1e66' is also a valid float literal.\n",
    "df = pd.read_csv('Score_decoys_docking_CASF2016.csv', dtype={'PDB_ID': str})\n",
    "pdbids = df.PDB_ID.unique()\n",
    "print(len(pdbids))\n",
    "\n",
    "score_columns = ['Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all']\n",
    "\n",
    "# One output tree per score column, one two-column, tab-separated file per PDB id.\n",
    "for column in score_columns:\n",
    "    # 'Score_3A' -> 'DockingPower_DeepDock_3A/scores'\n",
    "    out_dir = os.path.join('DockingPower_DeepDock_' + column.split('_', 1)[1], 'scores')\n",
    "    os.makedirs(out_dir, exist_ok=True)  # to_csv does not create missing directories\n",
    "\n",
    "    for pdbid in pdbids:\n",
    "        df1 = df.loc[df.PDB_ID == pdbid, ['Cpd_Name', column]]\n",
    "        df1.columns = ['#code', 'score']\n",
    "        df1.to_csv(os.path.join(out_dir, pdbid + '_score.dat'), index=False, sep='\\t')"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "python /data/CASF-2016/power_docking/docking_power.py -c /data/CASF-2016/power_docking/CoreSet.dat -s scores -r /data/CASF-2016/decoys_docking/ -p 'positive' -l 2 -o 'DeepDock' > DockingPower_DeepDock_all.out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
